{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read and Decode BBCI Data with Start-Stop-Markers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial shows how to read and decode BBCI data with start and stop markers. The data loading part is more complicated than in the other tutorials, so it is advisable to read those first."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup logging to see outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.DEBUG, stream=sys.stdout)\n",
    "log = logging.getLogger()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and preprocess data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a bit more complicated than before, since we also have to add breaks. Here I opt to add the breaks and do all preprocessing per run, and only combine the runs afterwards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from braindecode.datautil.splitters import concatenate_sets\n",
    "from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne, add_breaks\n",
    "from braindecode.datasets.bbci import load_bbci_sets_from_folder\n",
    "from copy import deepcopy\n",
    "from braindecode.mne_ext.signalproc import resample_cnt, mne_apply\n",
    "from braindecode.datautil.signalproc import lowpass_cnt\n",
    "from braindecode.datautil.signalproc import exponential_running_standardize\n",
    "\n",
    "def create_cnts(folder, runs, channels=('C3', 'CPz', 'C4'),\n",
    "                lowpass_hz=38, resample_hz=250):\n",
    "    \"\"\"Load BBCI runs from a folder and preprocess every run on its own.\n",
    "\n",
    "    Each loaded recording is reduced to `channels`, lowpass-filtered,\n",
    "    resampled and finally standardized with an exponential running\n",
    "    mean/variance.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    folder : str\n",
    "        Handed unchanged to `load_bbci_sets_from_folder`.\n",
    "    runs : list of int\n",
    "        Handed unchanged to `load_bbci_sets_from_folder`.\n",
    "    channels : sequence of str\n",
    "        Names of the channels to keep.\n",
    "    lowpass_hz : float\n",
    "        Cutoff passed to `lowpass_cnt`.\n",
    "    resample_hz : float\n",
    "        Target sampling rate passed to `resample_cnt`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        One preprocessed recording per item returned by\n",
    "        `load_bbci_sets_from_folder`, in the same order.\n",
    "    \"\"\"\n",
    "    new_cnts = []\n",
    "    for cnt in load_bbci_sets_from_folder(folder, runs):\n",
    "        # Only take some channels\n",
    "        # (cnt.drop_channels(['STI 014']) would remove the stimulus channel instead)\n",
    "        cnt = cnt.pick_channels(list(channels))\n",
    "        log.info(\"Preprocessing....\")\n",
    "        # Read the sampling rate up front: `cnt` is rebound several times below,\n",
    "        # and the lambda should not depend on which object the name points to\n",
    "        # at the moment it is called.\n",
    "        sfreq = cnt.info['sfreq']\n",
    "        cnt = mne_apply(lambda a: lowpass_cnt(a, lowpass_hz, sfreq, axis=1), cnt)\n",
    "        cnt = resample_cnt(cnt, resample_hz)\n",
    "        # mne_apply will apply the function to the data (a 2d-numpy-array).\n",
    "        # We have to transpose the data back and forth, since\n",
    "        # exponential_running_standardize expects time x chans order\n",
    "        # while the mne object has chans x time order.\n",
    "        cnt = mne_apply(lambda a: exponential_running_standardize(\n",
    "            a.T, init_block_size=1000, factor_new=0.001, eps=1e-4).T,\n",
    "            cnt)\n",
    "        new_cnts.append(cnt)\n",
    "    return new_cnts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "train_runs = [1, 2, 3]\n",
    "train_cnts = create_cnts(\n",
    "    '/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/',\n",
    "    train_runs,\n",
    ")\n",
    "\n",
    "# Start marker code(s) per class name\n",
    "name_to_start_code = OrderedDict([\n",
    "    ('Right Hand', 1),\n",
    "    ('Feet', 4),\n",
    "    ('Rotation', 8),\n",
    "    ('Words', [10]),\n",
    "])\n",
    "\n",
    "# Every class shares the same stop marker codes; each class still gets\n",
    "# its own list object.\n",
    "name_to_stop_code = OrderedDict(\n",
    "    (name, [20, 21, 22, 23, 24, 28, 30]) for name in name_to_start_code)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs used for the test set; loaded and preprocessed exactly like the training runs\n",
    "test_runs = [9, 10]\n",
    "test_cnts = create_cnts(\n",
    "    '/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/',\n",
    "    test_runs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create the model already at this point, since we need to know its receptive field size to cut out the data correctly: each cut-out has to start receptive_field_size samples before the first sample we want to predict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "from braindecode.models.util import to_dense_prediction_model\n",
    "\n",
    "import torch\n",
    "\n",
    "# Use the GPU when one is available, so the notebook also runs on CPU-only\n",
    "# machines. Assign True/False explicitly to force a device.\n",
    "cuda = torch.cuda.is_available()\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "\n",
    "# This will determine how many crops are processed in parallel\n",
    "input_time_length = 650\n",
    "in_chans = train_cnts[0].get_data().shape[0]\n",
    "# final_conv_length determines the size of the receptive field of the ConvNet\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=5, input_time_length=input_time_length,\n",
    "                        final_conv_length=29).create_network()\n",
    "to_dense_prediction_model(model)\n",
    "\n",
    "if cuda:\n",
    "    model.cuda()\n",
    "from braindecode.torch_ext.util import np_to_var\n",
    "import numpy as np\n",
    "# Determine the output size by pushing a dummy batch through the network\n",
    "test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))\n",
    "if cuda:\n",
    "    test_input = test_input.cuda()\n",
    "out = model(test_input)\n",
    "n_preds_per_input = out.cpu().data.numpy().shape[2]\n",
    "print(\"{:d} predictions per input/trial\".format(n_preds_per_input))\n",
    "# NOTE(review): this is the number of input samples consumed before the first\n",
    "# prediction; the receptive field itself may be one sample larger - confirm.\n",
    "n_receptive_field = input_time_length - n_preds_per_input\n",
    "receptive_field_ms = n_receptive_field * 1000.0 / train_cnts[0].info['sfreq']\n",
    "print(\"Receptive field: {:d}/{:.2f} (samples/ms)\".format(n_receptive_field,\n",
    "                                                      receptive_field_ms))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create SignalAndTarget Sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.datautil.trial_segment import create_signal_target_with_breaks_from_mne\n",
    "\n",
    "# NOTE(review): break_start_offset_ms used to be 1000 but was never passed to\n",
    "# anything, while the call below hard-coded [500, -500]. The two offsets are now\n",
    "# actually used, with the values that were in effect - confirm 500 is intended.\n",
    "break_start_offset_ms = 500\n",
    "break_stop_offset_ms = -500\n",
    "\n",
    "# One SignalAndTarget set per run, combined into a single training set afterwards\n",
    "train_sets = [create_signal_target_with_breaks_from_mne(\n",
    "    cnt, name_to_start_code, [0,0], \n",
    "    name_to_stop_code, min_break_length_ms=1000, max_break_length_ms=10000,\n",
    "    break_epoch_ival_ms=[break_start_offset_ms, break_stop_offset_ms],\n",
    "    prepad_trials_to_n_samples=input_time_length) \n",
    "              for cnt in train_cnts]\n",
    "train_set = concatenate_sets(train_sets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One SignalAndTarget set per test run, combined into a single test set afterwards\n",
    "test_sets = [\n",
    "    create_signal_target_with_breaks_from_mne(\n",
    "        cnt,\n",
    "        name_to_start_code,\n",
    "        [0, 0],\n",
    "        name_to_stop_code,\n",
    "        min_break_length_ms=1000,\n",
    "        max_break_length_ms=10000,\n",
    "        break_epoch_ival_ms=[500, -500],\n",
    "        prepad_trials_to_n_samples=input_time_length,\n",
    "    )\n",
    "    for cnt in test_cnts\n",
    "]\n",
    "test_set = concatenate_sets(test_sets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.splitters import split_into_two_sets\n",
    "\n",
    "# Split a fresh concatenation of the per-run sets instead of `train_set` itself,\n",
    "# so that re-running this cell does not shrink the training set a second time.\n",
    "train_set, valid_set = split_into_two_sets(\n",
    "    concatenate_sets(train_sets), first_set_fraction=0.8)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup optimizer and iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.iterators import CropsFromTrialsIterator\n",
    "# The crop size and the number of predictions per crop must match the model built above\n",
    "iterator = CropsFromTrialsIterator(\n",
    "    batch_size=32,\n",
    "    input_time_length=input_time_length,\n",
    "    n_preds_per_input=n_preds_per_input,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Monitors, Loss function, Stop Criteria"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.experiments.experiment import Experiment\n",
    "from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor\n",
    "from braindecode.experiments.stopcriteria import MaxEpochs\n",
    "import torch.nn.functional as F\n",
    "import torch as th\n",
    "from braindecode.torch_ext.modules import Expression\n",
    "from braindecode.torch_ext.losses import log_categorical_crossentropy\n",
    "\n",
    "\n",
    "loss_function = log_categorical_crossentropy\n",
    "model_constraint = None\n",
    "stop_criterion = MaxEpochs(20)\n",
    "\n",
    "monitors = [\n",
    "    LossMonitor(),\n",
    "    MisclassMonitor(col_suffix='sample_misclass'),\n",
    "    CroppedTrialMisclassMonitor(input_time_length),\n",
    "    RuntimeMonitor(),\n",
    "]\n",
    "\n",
    "# Positional arguments are kept in their original order\n",
    "exp = Experiment(\n",
    "    model, train_set, valid_set, test_set,\n",
    "    iterator, loss_function, optimizer, model_constraint,\n",
    "    monitors, stop_criterion,\n",
    "    remember_best_column='valid_misclass',\n",
    "    run_after_early_stop=True,\n",
    "    batch_modifier=None,\n",
    "    cuda=cuda,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "exp.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We arrive at about 54% accuracy. With only 3 sensors and 3 training runs, we cannot get much better :)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
